{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import time\n",
    "import os\n",
    "import argparse\n",
    "\n",
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "from segnet import SegNet\n",
    "from loss import DiscriminativeLoss\n",
    "from dataset import tuSimpleDataset\n",
    "from logger import Logger"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Configuration and tuSimple dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "INPUT_CHANNELS = 3\n",
    "OUTPUT_CHANNELS = 2\n",
    "LEARNING_RATE = 1e-4\n",
    "BATCH_SIZE = 16\n",
    "NUM_EPOCHS = 200\n",
    "LOG_INTERVAL = 10\n",
    "SIZE = [224, 224] # vgg16 inputs size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summary writer; logs are written under ../logs.\n",
    "logger = Logger('../logs')\n",
    "\n",
    "# Training split, with samples resized to SIZE by the dataset class.\n",
    "train_path = '/data/tuSimple/train_set/'\n",
    "train_dataset = tuSimpleDataset(train_path, size=SIZE)\n",
    "\n",
    "# Shuffled mini-batches loaded by 16 worker processes.\n",
    "train_dataloader = torch.utils.data.DataLoader(\n",
    "    dataset=train_dataset,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    shuffle=True,\n",
    "    num_workers=16,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model, losses and optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SegNet on the GPU.\n",
    "model = SegNet(input_ch=INPUT_CHANNELS, output_ch=OUTPUT_CHANNELS).cuda()\n",
    "\n",
    "# Resume from the checkpoint that train() saves, if one is present.\n",
    "if os.path.isfile(\"model_best.pth\"):\n",
    "    model.load_state_dict(torch.load(\"model_best.pth\"))\n",
    "    # Report only after the weights were actually restored, so a failed\n",
    "    # load never prints a misleading success message.\n",
    "    print(\"Loaded model_best.pth\")\n",
    "\n",
    "# Cross-entropy loss is applied to sem_pred in train().\n",
    "criterion_ce = torch.nn.CrossEntropyLoss().cuda()\n",
    "\n",
    "# Discriminative loss is applied to ins_pred in train().\n",
    "criterion_disc = DiscriminativeLoss(delta_var=0.1,\n",
    "                                    delta_dist=0.6,\n",
    "                                    norm=2,\n",
    "                                    usegpu=True).cuda()\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)\n",
    "\n",
    "# Multiply the learning rate by 0.9 at each listed epoch milestone.\n",
    "scheduler = torch.optim.lr_scheduler.MultiStepLR(\n",
    "    optimizer,\n",
    "    milestones=[20, 40, 60, 80, 100, 120, 140, 160, 180],\n",
    "    gamma=0.9)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train():\n",
    "    \"\"\"Train `model` on `train_dataloader` for NUM_EPOCHS epochs.\n",
    "\n",
    "    Every batch is optimised on the sum of the discriminative loss computed\n",
    "    on `ins_pred` and the cross-entropy loss computed on `sem_pred`. Every\n",
    "    LOG_INTERVAL batches the scalar losses, parameter histograms and a few\n",
    "    sample images are written to `logger`. After each epoch the LR scheduler\n",
    "    is stepped and the weights are saved to 'model_best.pth' whenever the\n",
    "    epoch's mean training loss beats the best value seen so far.\n",
    "\n",
    "    Uses the notebook-level globals created in the cells above (model,\n",
    "    optimizer, scheduler, criteria, dataloader, logger) and returns None.\n",
    "    \"\"\"\n",
    "    # Based on: https://github.com/Sayan98/pytorch-segnet/blob/master/src/train.py\n",
    "    best_loss = float('inf')\n",
    "    batches_per_epoch = len(train_dataloader)\n",
    "\n",
    "    model.train()\n",
    "\n",
    "    for epoch in range(NUM_EPOCHS):\n",
    "        t_start = time.time()\n",
    "        epoch_losses = []\n",
    "\n",
    "        for batch_idx, (imgs, sem_labels, ins_labels) in enumerate(train_dataloader):\n",
    "            # Plain tensors track gradients themselves; the deprecated\n",
    "            # torch.autograd.Variable wrapper is not needed.\n",
    "            img_tensor = imgs.cuda()\n",
    "            sem_tensor = sem_labels.cuda()\n",
    "            ins_tensor = ins_labels.cuda()\n",
    "\n",
    "            # Init gradients\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            # Predictions\n",
    "            sem_pred, ins_pred = model(img_tensor)\n",
    "\n",
    "            # Discriminative Loss\n",
    "            # The third argument holds one entry (5) per sample; presumably the\n",
    "            # per-image instance count expected by DiscriminativeLoss - TODO confirm.\n",
    "            disc_loss = criterion_disc(ins_pred, ins_tensor, [5] * len(img_tensor))\n",
    "\n",
    "            # CrossEntropy Loss\n",
    "            # Move dim 1 last and flatten so every row holds OUTPUT_CHANNELS scores,\n",
    "            # matched against the flattened label tensor.\n",
    "            ce_loss = criterion_ce(sem_pred.permute(0,2,3,1).contiguous().view(-1,OUTPUT_CHANNELS), sem_tensor.view(-1))\n",
    "\n",
    "            loss = disc_loss + ce_loss\n",
    "\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            # .item() yields a detached Python float for the running average.\n",
    "            epoch_losses.append(loss.item())\n",
    "\n",
    "            if batch_idx % LOG_INTERVAL == 0:\n",
    "                print('\\tTrain Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
    "                    epoch, batch_idx * len(imgs), len(train_dataloader.dataset),\n",
    "                    100. * batch_idx / batches_per_epoch, loss.item()))\n",
    "\n",
    "                # Step that keeps increasing across epochs. Using batch_idx + 1\n",
    "                # alone would restart at 1 every epoch and make later epochs\n",
    "                # collide with earlier ones in the summaries.\n",
    "                global_step = epoch * batches_per_epoch + batch_idx + 1\n",
    "\n",
    "                # 1. Log scalar losses (Tensorboard)\n",
    "                scalars = {'loss': loss.item(), 'disc_loss': disc_loss.item(), 'ce_loss': ce_loss.item()}\n",
    "\n",
    "                for tag, value in scalars.items():\n",
    "                    logger.scalar_summary(tag, value, global_step)\n",
    "\n",
    "                # 2. Log values and gradients of the parameters (histogram summary)\n",
    "                for tag, param in model.named_parameters():\n",
    "                    tag = tag.replace('.', '/')\n",
    "                    logger.histo_summary(tag, param.detach().cpu().numpy(), global_step)\n",
    "                    # logger.histo_summary(tag + '/grad', param.grad.detach().cpu().numpy(), global_step)\n",
    "\n",
    "                # 3. Log training images (image summary)\n",
    "                # NOTE(review): the 3 / 2 / 224 literals below mirror INPUT_CHANNELS,\n",
    "                # OUTPUT_CHANNELS and SIZE; keep them in sync if those change.\n",
    "                image_batches = {'images': img_tensor.view(-1, 3, 224, 224)[:10].cpu().numpy(),\n",
    "                                 'labels': sem_tensor.view(-1, 224, 224)[:10].cpu().numpy(),\n",
    "                                 'sem_preds': sem_pred.view(-1, 2, 224, 224)[:10,1].detach().cpu().numpy(),\n",
    "                                 'ins_preds': ins_pred.view(-1, 224, 224)[:10].detach().cpu().numpy()}\n",
    "\n",
    "                for tag, images in image_batches.items():\n",
    "                    logger.image_summary(tag, images, global_step)\n",
    "\n",
    "        dt = time.time() - t_start\n",
    "        epoch_loss = np.mean(epoch_losses)\n",
    "        scheduler.step()\n",
    "\n",
    "        # Keep the checkpoint with the lowest mean training loss so far.\n",
    "        if epoch_loss < best_loss:\n",
    "            best_loss = epoch_loss\n",
    "            print(\"\\t\\tBest Model.\")\n",
    "            torch.save(model.state_dict(), \"model_best.pth\")\n",
    "\n",
    "        # '.2f' / '.2e' are precision specs; the former '2f' only set a minimum width.\n",
    "        print(\"Epoch #{}\\tLoss: {:.8f}\\t Time: {:.2f}s, Lr: {:.2e}\".format(\n",
    "            epoch+1, epoch_loss, dt, optimizer.param_groups[0]['lr']))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
